# ***** BEGIN LICENSE BLOCK *****
# Version: MPL 1.1/GPL 2.0/LGPL 2.1
#
# The contents of this file are subject to the Mozilla Public License Version
# 1.1 (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
# http://www.mozilla.org/MPL/
#
# Software distributed under the License is distributed on an "AS IS" basis,
# WITHOUT WARRANTY OF ANY KIND, either express or implied. See the License
# for the specific language governing rights and limitations under the
# License.
#
# The Original Code is mozilla.org code.
#
# Contributor(s):
#   Chris Jones <jones.chris.g@gmail.com>
#
# Alternatively, the contents of this file may be used under the terms of
# either of the GNU General Public License Version 2 or later (the "GPL"),
# or the GNU Lesser General Public License Version 2.1 or later (the "LGPL"),
# in which case the provisions of the GPL or the LGPL are applicable instead
# of those above. If you wish to allow use of your version of this file only
# under the terms of either the GPL or the LGPL, and not to allow others to
# use your version of this file under the terms of the MPL, indicate your
# decision by deleting the provisions above and replace them with the notice
# and other provisions required by the GPL or the LGPL. If you do not delete
# the provisions above, a recipient may use your version of this file under
# the terms of any one of the MPL, the GPL or the LGPL.
#
# ***** END LICENSE BLOCK *****

import StringIO

class Rewriter:
    '''A named collection of type and call rewrite rules that can emit the
    C++ header ("<name>_rules_gen_h") declaring them as rule tables.'''

    def __init__(self, name, typerws, callrws):
        self.name = name
        self.typerws = typerws      # TypeRewrite rules
        self.callrws = callrws      # CallRewrite rules
        self._uid = 0               # next id handed out by uid()

    def uid(self):
        '''Return a fresh integer id, unique among ids from this Rewriter.'''
        tuid = self._uid
        self._uid += 1
        return tuid

    def gencode(self, out):
        '''Write the generated rules header to the file-like object |out|.'''
        # Generate the type rules before the call rules so uids are handed
        # out in the same order as before.
        typerules = self.genTypeRuleDecls()
        callrules = self.genCallRuleDecls()
        out.write('''
// Automatically generated by porky (http://TODO)
// from the rewrite specification:
/*
%(spec)s
*/
#ifndef %(name)s_rules_gen_h
#define %(name)s_rules_gen_h

#define REWRITER %(name)s

// Type rewrite rules
%(typerules)s

// Call rewrite rules
%(callrules)s

#endif // %(name)s_rules_gen_h
''' %{ 'spec': self._specComment(),
       'name': self.name,
       'typerules': typerules,
       'callrules': callrules
       })

    def _specComment(self):
        '''Return repr(self), made safe to embed inside a C block comment.'''
        # a literal "*/" in the spec would end the comment early and leave
        # the rest of the spec as (invalid) C++ code
        return repr(self).replace('*/', '* /')

    def _genRuleDecls(self, rws, ctype, tablename, nullrule):
        '''Return the declarations for every rule in |rws|, followed by a
        |ctype| array named |tablename| that lists each rule and ends with
        |nullrule| as a terminator.'''
        # gendecls() must run before genrule(): the rules' genrule() builds
        # its entry from the uid that gendecls() assigns.
        decls = [ rw.gendecls(self.uid()) for rw in rws ]
        entries = [ rw.genrule() for rw in rws ] + [ nullrule ]
        return ('\n'.join(decls)
                + '\n%s %s[] = {\n' % (ctype, tablename)
                + ',\n'.join(entries)
                + ' };')

    def genTypeRuleDecls(self):
        '''Return the C declarations and _type_rules table for the type
        rewrite rules.'''
        return self._genRuleDecls(self.typerws, 'TypeRule', '_type_rules',
                                  TypeRewrite.gennullrule())

    def genCallRuleDecls(self):
        '''Return the C declarations and _call_rules table for the call
        rewrite rules.'''
        return self._genRuleDecls(self.callrws, 'CallRule', '_call_rules',
                                  CallRewrite.gennullrule())

    def __repr__(self):
        s = '<Rewriter name='+ repr(self.name) +'\n'
        for typerw in self.typerws:
            s += '  '+ repr(typerw) +'\n'
        s += '\n'
        for callrw in self.callrws:
            s += '  '+ repr(callrw) +'\n'
        return s +'>'

    @classmethod
    def create(cls, name, rewrites):
        '''Build a Rewriter from a mixed list of TypeRewrite and CallRewrite
        rules, with each kind sorted by decreasing specificity.'''
        # cmp(y, x) reverses the rules' own "<" (= less specific) ordering,
        # so the most specific rules come first
        typerws = [ rw for rw in rewrites if isinstance(rw, TypeRewrite) ]
        typerws.sort(cmp=lambda x, y: cmp(y, x))
        callrws = [ rw for rw in rewrites if isinstance(rw, CallRewrite) ]
        callrws.sort(cmp=lambda x, y: cmp(y, x))
        return cls(name, typerws, callrws)


class CxxQualifiedId:
    '''A possibly scope-qualified C++ name such as |a::b::c|.

    The innermost component is |id|; the enclosing components, in the
    order they were pushed by qualify(), are kept in |scopes|.'''

    def __init__(self, id):
        self.scopes = [ ]
        self.id = id

    def qualify(self, id):
        '''Make |id| the new last component, moving the current last
        component onto the end of |scopes|.'''
        previous = self.id
        self.id = id
        self.scopes.append(previous)

    def iterparts(self):
        '''Yield every component: the scopes first, then |id|.'''
        for scope in self.scopes:
            yield scope
        yield self.id

    def numparts(self):
        '''Return how many components this name has (always >= 1).'''
        return 1 + len(self.scopes)

    def __getitem__(self, i):
        last = self.numparts() - 1
        if i == last:
            return self.id
        if i > last:
            raise IndexError
        return self.scopes[i]

    def __cmp__(self, o):
        '''"<" means less specific'''
        mine, theirs = self.numparts(), o.numparts()
        for i in range(min(mine, theirs)):
            ours, other = self[i], o[i]
            if ours != other:
                if mine == theirs:
                    return cmp(ours, other)
                # names of different length: the longer one is more specific
                return mine - theirs
        return mine - theirs

    def __hash__(self):
        return hash(repr(self))

    def _joinparts(self, render):
        '''Join the rendered components with "::".  Like the original
        reduce()-based join, an empty accumulated prefix is dropped rather
        than followed by "::".'''
        text = ''
        for part in self.iterparts():
            piece = render(part)
            if text:
                text = text + '::' + piece
            else:
                text = piece
        return text

    def __repr__(self):
        return self._joinparts(repr)

    def __str__(self):
        return self._joinparts(str)


class CxxType:
    '''A C++ type: a base name plus the declarator qualifiers (pointer,
    reference, array) applied to it, in the order they were added.'''
    POINTER = ('*',)
    REFERENCE = ('&',)
    ARRAY = ('[]',)

    def __init__(self, base_name):
        self.base_name = base_name # CxxQualifiedId
        self.qualifiers = [ ]
        self.terminal = False   # no more qualifiers allowed

    def pointer(self):
        '''Append a pointer qualifier.'''
        assert not self.terminal
        self.qualifiers.append(self.POINTER)

    def reference(self):
        '''Append a reference qualifier; no qualifier may follow it.'''
        assert not self.terminal
        self.qualifiers.append(self.REFERENCE)
        self.terminal = True

    def array(self):
        '''Append an array qualifier.'''
        assert not self.terminal
        self.qualifiers.append(self.ARRAY)

    def __cmp__(self, o):
        '''"<" means less specific: more qualifiers is more specific; ties
        are broken by base name, then by the qualifiers themselves.'''
        selfqlen, oqlen = len(self.qualifiers), len(o.qualifiers)
        if selfqlen != oqlen:
            return selfqlen - oqlen
        # Compare the qualifiers too; otherwise e.g. T* and T[] would be
        # "equal" while hashing differently (hash is based on repr).
        return (cmp(self.base_name, o.base_name)
                or cmp(self.qualifiers, o.qualifiers))
    def __hash__(self):
        return hash(repr(self))
    def __repr__(self):
        r = repr(self.base_name)
        for q in self.qualifiers:  r += repr(q[0])
        return r
    def __str__(self):
        r = str(self.base_name)
        for q in self.qualifiers:  r += str(q[0])
        return r


class CxxExprSelect:
    '''A member selection |base|.|field| (seltype DOT) or |base|->|field|
    (seltype PDOT).'''
    DOT = ('.',)
    PDOT = ('->',)

    def __init__(self, base, seltype, field):
        self.base = base
        self.seltype = seltype
        self.field = field

    def __cmp__(self, o):
        return cmp(self.base, o.base) or cmp(self.field, o.field) or cmp(self.seltype, o.seltype)
    def __hash__(self):
        return hash(repr(self))
    def __repr__(self):
        return '%r%s%r'% (self.base, self.seltype[0], self.field)
    def __str__(self):
        # Plain source form, like the other AST nodes' __str__; without this,
        # str()/"%s" fell back to __repr__ and quoted the pieces.
        return '%s%s%s'% (self.base, self.seltype[0], self.field)


class CxxExprCall:
    '''A C++ call expression |callable|(|args|), optionally written as a
    |new| or |delete| expression.  |args| may be None when no argument
    list was specified.'''
    def __init__(self, callable, args, 
                 ismethodcall=False, isnew=False, isdelete=False):
        self.callable = callable
        self.args = args
        self.ismethodcall = ismethodcall # so we can re-sugar if necessary
        self.isnew = isnew
        self.isdelete = isdelete

    def __cmp__(self, o):
        '''"<" means less specific.

        Keys, most significant first:
          1. new/delete expressions are more specific than plain calls;
          2. the callables' own ordering;
          3. a call with an argument list is more specific than one
             without (two calls without argument lists tie);
          4. argument count (not a real specificity relation; it just
             keeps the sort tidy);
          5. fewer wildcard arguments is more specific.'''
        if (o.isnew and not self.isnew
            or o.isdelete and not self.isdelete):
            return -1
        elif (self.isnew and not o.isnew
            or self.isdelete and not o.isdelete):
            return 1
        elif self.callable < o.callable:
            return -1
        elif self.callable > o.callable:
            return 1

        selfnoargs, onoargs = self.args is None, o.args is None
        if selfnoargs or onoargs:
            # -1 / 1 when only one side lacks an argument list, 0 when both do
            return int(onoargs) - int(selfnoargs)
        if len(self.args) != len(o.args):
            # not comparable, but makes the sort look nicer
            return len(self.args) - len(o.args)

        def nwildcards(args):
            return sum([ 1 for p in args
                         if isinstance(p, Pattern) and p.iswildcard() ])
        # A wildcard is less specific than a literal (see Pattern.__cmp__),
        # so the call with more wildcards is the less specific one.
        return nwildcards(o.args) - nwildcards(self.args)
    def __hash__(self):
        return hash(repr(self))
    def __repr__(self):
        if self.isdelete:
            return 'delete %r'% (self.args[0],)
        if self.isnew:  pfx = 'new '
        else:           pfx =  ''
        return '%s%r(%r)'% (pfx, self.callable, self.args)
    def __str__(self):
        if self.isdelete:
            return 'delete %s'% (self.args[0],)
        if self.isnew:  pfx = 'new '
        else:           pfx =  ''
        return '%s%s(%s)'% (pfx, self.callable, self.args)


# TODO this is a hack, and not a good AST design at all
class CxxExprAtom:
    '''A single expression token |atom|, optionally dereferenced (*atom).'''
    def __init__(self, atom, isderef=False):
        self.atom = atom
        self.isderef = isderef

    def __cmp__(self, o):
        '''"<" means less specific: a dereferenced atom sorts after a plain
        one; otherwise atoms are ordered by their text.'''
        if self.atom == o.atom and self.isderef == o.isderef:
            return 0
        if o.isderef and not self.isderef:
            return -1
        if self.isderef and not o.isderef:
            return 1
        return cmp(self.atom, o.atom)

    def __hash__(self):
        return hash(repr(self))
    def __repr__(self):
        s = repr(self.atom)
        if self.isderef:
            return '*'+ s
        else:
            return s
    # Previously a second "__repr__" that silently replaced the one above;
    # the plain-text form is the str() conversion.
    def __str__(self):
        s = str(self.atom)
        if self.isderef:
            return '*'+ s
        else:
            return s


class TypeRewrite:
    '''A rule that rewrites the C++ type |fromtype| into |totype|.

    gendecls() emits four C string arrays (base name and qualifier string
    for each side); genrule() emits the TypeRule table entry that refers to
    them by name.'''

    def __init__(self, fromtype, totype):
        self.fromtype, self.totype = fromtype, totype
        # set by gendecls(); every *_decl() name is derived from it
        self.uid = None

    def gendecls(self, uid):
        '''Record |uid| as this rule's id and return its C string decls.'''
        self.uid = str(uid)

        def qualstring(ty):
            # concatenation of each qualifier's spelling, e.g. "*&"
            return ''.join([ qual[0] for qual in ty.qualifiers ])

        named = [
            (self.base_name_decl(), str(self.fromtype.base_name)),
            (self.qualifiers_decl(), qualstring(self.fromtype)),
            (self.to_base_name_decl(), str(self.totype.base_name)),
            (self.to_qualifiers_decl(), qualstring(self.totype)),
        ]
        return '\n'.join([ 'char %s[] = "%s";' % pair for pair in named ])

    def genrule(self):
        '''Return this rule's TypeRule initializer.'''
        fields = [ self.base_name_decl(), self.qualifiers_decl(),
                   self.to_base_name_decl(), self.to_qualifiers_decl() ]
        return '{ %s }' % ', '.join(fields)

    def _decl(self, suffix):
        return self.declbase() + suffix

    def base_name_decl(self):
        return self._decl('_base_name')

    def qualifiers_decl(self):
        return self._decl('_qualifiers')

    def to_base_name_decl(self):
        return self._decl('_to_base_name')

    def to_qualifiers_decl(self):
        return self._decl('_to_qualifiers')

    def declbase(self):
        # string concatenation: fails if gendecls() has not assigned uid yet
        return '_type_rule_' + self.uid

    @classmethod
    def gennullrule(cls):
        '''Return the all-zero entry that terminates the TypeRule table.'''
        return '{ 0, 0, 0, 0 }'

    def __cmp__(self, o):
        '''"<" here means "less specific".'''
        return cmp(self.fromtype, o.fromtype)
    def __hash__(self):
        return hash(repr(self))
    def __repr__(self):
        return '<TypeRewrite from=%r to=%r>' % (self.fromtype, self.totype)


class Pattern:
    '''One piece of a rewrite rule: literal |text|, or, when |argno| is not
    None, a wildcard standing for captured argument number |argno|.'''
    def __init__(self, text, argno=None):
        self.text = text
        self.argno = argno

    def iswildcard(self):
        return self.argno is not None

    def __cmp__(self, o):
        '''Return a negative, zero or positive number as |self| is less
        than, equal to or greater than |o|.  "<" here means "less
        specific", so a wildcard pattern is less specific than a literal
        pattern.'''
        if self.iswildcard() == o.iswildcard():
            # alphanumeric sort looks nicer, but doesn't matter
            return cmp(self.text, o.text)
        elif self.iswildcard():
            return -1
        else:
            return 1
    def __hash__(self):
        return hash(repr(self))
    def __repr__(self):
        if self.iswildcard():
            tag = '$'+str(self.argno)
        else:
            tag = ''
        return '%s%r'% (tag, self.text)
    # Previously a second "__repr__" that silently replaced the one above;
    # the plain-text form is the str() conversion.
    def __str__(self):
        if self.iswildcard():
            tag = '$'+str(self.argno)
        else:
            tag = ''
        return '%s%s'% (tag, self.text)

class CallRewrite:
    '''A rule that rewrites calls matching |fromcall| into |tocall|.

    Both are CxxExprCalls whose pieces have been turned into Patterns by
    patternize(); use create() to build one from parsed calls.'''

    def __init__(self, fromcall, tocall):
        self.fromcall = fromcall
        self.tocall = tocall
        # set by gendecls(); every *_decl() name is derived from it
        self.uid = None

    def __cmp__(self, o):
        '''"<" means "less specific"; decided by the |fromcall| patterns.'''
        return cmp(self.fromcall, o.fromcall)
    def __hash__(self):
        return hash(repr(self))
    def __repr__(self):
        return '<CallRewrite from=%r to=%r>'% (self.fromcall, self.tocall)

    @staticmethod
    def _cstring(text):
        '''Format |text| with "%s" and escape it for use inside a C
        double-quoted string literal.'''
        # Backslashes first, so the ones added for quotes aren't doubled.
        return ('%s' % (text,)).replace('\\', '\\\\').replace('"', '\\"')

    def gendecls(self, uid):
        '''Return the C declarations for this rule, naming them from |uid|.

        Emits a string constant naming the matched callable, one string
        constant per literal piece of the replacement call, and a WriteInstr
        program (EMIT_STRING / EMIT_CAPT_ARG instructions terminated by
        HALT) that writes out |tocall|.'''
        self.uid = str(uid)     # uid to use for all this rule's decls
        self._suid = 0          # intra-rule uid for declaring strings

        stringdecls = [ 'const char %s[] = "%s";\n'% (
                self.callable_decl(), self._cstring(self.fromcall.callable)) ]
        progdecl = [ 'WriteInstr %s[] = {\n'% (self.prog_decl()) ]

        def strdecl():
            t = self.prog_decl() +'_str%s'% (self._suid)
            self._suid += 1
            return t

        def _emitdeclstr(decl):
            progdecl.append(
                '  { WriteInstr::EMIT_STRING, (uintptr_t) %s },\n'% (decl))
        def emitdot(): _emitdeclstr('dot')
        def emitarrow(): _emitdeclstr('arrow')
        def emitcomma(): _emitdeclstr('comma')
        def emitlparen(): _emitdeclstr('lparen')
        def emitrparen(): _emitdeclstr('rparen')
        def emitstar(): _emitdeclstr('star')
        def emitnew(): _emitdeclstr('_new')
        def emitdelete(): _emitdeclstr('_delete')

        def emitstring(text):
            sdecl = strdecl()
            stringdecls.append('const char %s[] = "%s";\n'% (
                    sdecl, self._cstring(text)))
            _emitdeclstr(sdecl)

        def emitcapture(i):
            progdecl.append(
                '  { WriteInstr::EMIT_CAPT_ARG, (uintptr_t) %s },\n'% (i))

        tocall = self.tocall

        # generate the tocall callable
        if tocall.isnew:
            emitnew()
        elif tocall.isdelete:
            emitdelete()

        if isinstance(tocall.callable, CxxQualifiedId):
            emitstring(str(tocall.callable))
            # TODO support more expr types
        elif isinstance(tocall.callable, CxxExprSelect):
            # base obj
            base = tocall.callable.base
            if base.iswildcard():
                emitcapture(base.argno)
            else:
                emitstring(base.text)

            # selector
            if CxxExprSelect.DOT == tocall.callable.seltype:
                emitdot()
            else:
                emitarrow()

            # field
            emitstring(tocall.callable.field)

        # generate the tocall args, if necessary
        if tocall.args is not None:
            # |None| means we replace the function name only
            if not tocall.isdelete: emitlparen()
            for i, arg in enumerate(tocall.args):
                if arg.text.isderef:
                    emitstar()
                if arg.iswildcard():
                    emitcapture(arg.argno)
                else:
                    emitstring(arg.text.atom)

                if i != (len(tocall.args) - 1):
                    emitcomma()
            if not tocall.isdelete: emitrparen()

        # done; emit the instruction saying so
        progdecl.append('  { WriteInstr::HALT, (uintptr_t) 0 } };')

        return ''.join(stringdecls) + ''.join(progdecl)

    def genrule(self):
        '''Return this rule's CallRule initializer.  The argument count is
        -1 when |fromcall| has no argument list.'''
        if self.fromcall.args is None:
            matchargs = -1
        else:
            matchargs = len(self.fromcall.args)
        return '{ %s, %s, %s }'% (self.callable_decl(),
                                  matchargs,
                                  self.prog_decl())

    def callable_decl(self):
        return self.base_decl()+ '_callable'

    def prog_decl(self):
        return self.base_decl()+ '_prog'

    def base_decl(self):
        # string concatenation: fails if gendecls() has not assigned uid yet
        return '_call_rule_'+ self.uid

    @classmethod
    def gennullrule(cls):
        '''Return the all-zero entry that terminates the CallRule table.'''
        return '{ 0, 0, 0 }'

    @classmethod
    def patternize(cls, call, wcs=None):
        '''Process |call| into its constituent Patterns, and return them
        in a CxxExprCall, together with a dict mapping each wildcard's name
        to its argument number.  If there's a token that might be either a
        wildcard or a pattern, then we check |wcs|.  If it's in there, we
        consider the argument to be a wildcard.
        If |wcs| is None, then we assume ambiguous arguments are wildcards
        (the base of a member selection is then treated as a literal).

        Raises TypeError if |call| is neither a CxxExprCall nor a delete
        expression.'''
        pats = { }

        def wckey(text):
            # the name a wildcard is registered under in |pats|
            if isinstance(text, CxxExprAtom):
                return text.atom
            return text

        def makepat(text, argno=None):
            pat = Pattern(text, argno)
            if pat.iswildcard():
                if not isinstance(text, (str, CxxExprAtom)):
                    raise AssertionError('not reached: %r' % (text,))
                pats[wckey(text)] = pat.argno
            return pat

        if isinstance(call, CxxExprCall):
            callee, args = call.callable, call.args
        elif call.isdelete:
            callee, args = None, call.args
        else:
            raise TypeError('cannot patternize %r' % (call,))

        # patternize callable
        callablepat = None
        if isinstance(callee, CxxQualifiedId):
            # no wildcard: it's specifying a function
            for part in callee.iterparts():
                if callablepat is None:
                    callablepat = CxxQualifiedId(makepat(part))
                else:
                    callablepat.qualify(makepat(part))
        elif isinstance(callee, CxxExprSelect):
            # ambiguity: is the base of the selection a wildcard or not?
            # note: this assumes that foo->Bar() can't be a |fromcall|
            # TODO support more complicated expressions
            base = callee.base
            if wcs is None:
                x = None
            else:
                # look up under the same key makepat() registers wildcards by
                x = wcs.get(wckey(base), None)
            callablepat = CxxExprSelect(
                makepat(base, x),
                callee.seltype,
                callee.field)

        # patternize args
        argspats = None
        if args is not None:
            # ambiguity: for each argument, we don't know if it's specifying
            # a wildcard or a name
            argspats = [ ]
            for i, arg in enumerate(args):
                if wcs is None:
                    x = i
                else:
                    x = wcs.get(arg.atom, None)
                argspats.append(makepat(arg, x))

        return (CxxExprCall(callablepat, argspats,
                            ismethodcall=call.ismethodcall,
                            isnew=call.isnew,
                            isdelete=call.isdelete),
                pats)

    @classmethod
    def create(cls, fromcall, tocall):
        '''Build a rule from parsed calls.  |fromcall|'s arguments become
        wildcards; tokens in |tocall| naming one of them become captures.'''
        fromcall, fcpats = cls.patternize(fromcall)
        tocall, _ = cls.patternize(tocall, fcpats)
        return cls(fromcall, tocall)
